import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import MetaTrader5 as mt5
import matplotlib.pyplot as plt
import joblib

# Connect to the MetaTrader 5 terminal. Everything below needs a live
# connection, so abort here instead of continuing and failing later with an
# unrelated error (e.g. a KeyError on the empty rates DataFrame).
if not mt5.initialize():
    init_error = mt5.last_error()
    mt5.shutdown()
    raise SystemExit(f"Initialize failed, error code = {init_error}")

# Load the most recent bars of market data.
symbol = "EURUSD"
timeframe = mt5.TIMEFRAME_M15
# 24 h * 60 min / 15 min per bar = 96 bars, i.e. 1440 minutes of M15 data.
# These are the 96 most recent bars, so they usually span two calendar dates.
n_bars = 24 * 60 // 15
rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, n_bars)
# copy_rates_from_pos returns None on failure; stop with the terminal's error
# instead of building an empty DataFrame that fails later on the 'time' column.
if rates is None or len(rates) == 0:
    rates_error = mt5.last_error()
    mt5.shutdown()
    raise SystemExit(f"copy_rates_from_pos failed for {symbol}, error code = {rates_error}")
mt5.shutdown()

# Turn the bars returned by MT5 into a DataFrame indexed by bar time;
# the 'time' field is converted from Unix epoch seconds.
data = pd.DataFrame(rates)
data['time'] = pd.to_datetime(data['time'], unit='s')
data = data.set_index('time')

# "Time token": the fraction of the day elapsed at the bar's timestamp
# (00:00 -> 0.0, 12:00 -> 0.5), computed at one-second resolution.
bar_times = data.index
seconds_since_midnight = bar_times.hour * 3600 + bar_times.minute * 60 + bar_times.second
data['time_token'] = seconds_since_midnight / 86400

def normalize_daily_rolling(data):
    """Scale OHLC prices against each calendar date's running high/low.

    For each bar, the reference range runs from the lowest ``low`` to the
    highest ``high`` seen so far on the same calendar date, including the
    bar itself. The range therefore resets at the first bar of every date.
    Prices inside that range map into [0, 1].

    Parameters
    ----------
    data : pandas.DataFrame
        Needs a DatetimeIndex and ``open``, ``high``, ``low``, ``close``
        columns. It is modified in place: ``date``, ``rolling_high``,
        ``rolling_low`` and ``norm_open``/``norm_high``/``norm_low``/
        ``norm_close`` columns are added.

    Returns
    -------
    pandas.DataFrame
        The same ``data`` object, for convenience.
    """
    data['date'] = data.index.date
    by_date = data.groupby('date')
    # A cumulative max/min within each date is an expanding max/min that
    # resets whenever the date changes.
    data['rolling_high'] = by_date['high'].cummax()
    data['rolling_low'] = by_date['low'].cummin()

    # A zero range (e.g. a flat first bar of the day) would be a division by
    # zero; mask it to NaN explicitly, then fill those results below.
    price_range = data['rolling_high'] - data['rolling_low']
    price_range = price_range.mask(price_range == 0)

    norm_cols = []
    for col in ('open', 'high', 'low', 'close'):
        norm_col = f'norm_{col}'
        data[norm_col] = (data[col] - data['rolling_low']) / price_range
        norm_cols.append(norm_col)

    # Replace the zero-range NaNs with 0 in the normalized columns only.
    # NaNs in other columns are left visible to the caller's NaN check.
    data[norm_cols] = data[norm_cols].fillna(0)
    return data

# Figure with three stacked panels. The top panel shows the raw close price,
# drawn before normalize_daily_rolling adds its derived columns.
plt.figure(figsize=(15, 10))

raw_ax = plt.subplot(3, 1, 1)
data['close'].plot(ax=raw_ax)
raw_ax.set_title('Close Prices')
raw_ax.set_xlabel('Time')
raw_ax.set_ylabel('Price')

data = normalize_daily_rolling(data)

# Report per-column counts of missing values, if there are any.
missing_per_column = data.isnull().sum()
if missing_per_column.any():
    print("Data contains NaNs")
    print(missing_per_column)

# Keep only the feature columns: the time-of-day token and normalized OHLC.
feature_columns = ['time_token', 'norm_open', 'norm_high', 'norm_low', 'norm_close']
data = data[feature_columns]

# Lower two panels: the normalized close price and the time-of-day token.
# Each entry is (subplot position, column, title, y-axis label).
lower_panels = [
    (2, 'norm_close', 'Normalized Close Prices', 'Normalized Price'),
    (3, 'time_token', 'Time Token', 'Time Token'),
]
for panel_pos, panel_col, panel_title, panel_ylabel in lower_panels:
    panel_ax = plt.subplot(3, 1, panel_pos)
    data[panel_col].plot(ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Time')
    panel_ax.set_ylabel(panel_ylabel)

plt.tight_layout()
plt.show()